import torch

if __name__ == '__main__':
    # Smoke check: make GPU 0 the current CUDA device, move a small tensor
    # onto it, and report which device the tensor ended up on.
    # NOTE(review): this needs a CUDA-enabled PyTorch build and a visible
    # GPU; the calls below are expected to raise otherwise.
    gpu_index = 0
    torch.cuda.set_device(gpu_index)

    # Build the tensor on the host first, then copy it over. `.cuda()` with
    # no argument targets the *current* CUDA device, i.e. the one selected
    # just above.
    host_values = torch.tensor([1, 2])
    gpu_values = host_values.cuda()
    print(gpu_values.device)
